{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "25ced253-004d-4d9a-8a20-b6864c02ba7f",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Training a GNN model for user cell prediction and making predictions using transductive inference\n",
    "\n",
    "This notebook goes over how to use Neptune ML to train a GNN model that can be deployed to a machine learning endpoint.\n",
    "The deployed model endpoint can then be used to make predictions with Gremlin Queries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dda89f43-076d-4445-a97b-a4693ccc1b88",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import neptune_ml_utils as neptune_ml\n",
    "neptune_ml.check_ml_enabled()\n",
    "\n",
    "\n",
    "s3_bucket_uri=\"<your s3 bucket for model training>\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20a71b23-af17-4a48-9877-c9537354fac1",
   "metadata": {
    "tags": []
   },
   "source": [
    "Before we starting model training, we will perform the following steps\n",
    "- drop some edges : Drop some links between some existing users and Cells to predict them with the transductive mode training.\n",
    "- for user_0 and cell_62000 we dropped ALL edges\n",
    "- for user_1500 and cell_56500 we dropped ALL edges  \n",
    "- for user_4570 three edges user_live_cell_1678350, user_live_cell_734598, user_live_cell_487137"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6123d8d1-a7dd-43c1-a4b6-9ce406cf5e01",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Select users to drop their user_live_cell edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe556883-1da6-4e96-98d4-1b01d727fa30",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%%gremlin\n",
    "g.V()\n",
    ".hasId(\"user_0\")\n",
    ".outE()\n",
    ".hasLabel(\"user_live_cell\")\n",
    ".inV()\n",
    ".valueMap(true, \"name\")\n",
    ".groupCount()\n",
    ".unfold()\n",
    ".order()\n",
    ".by(values, desc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdfddf87-df13-4fe4-acb8-ac1e5af66550",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%%gremlin\n",
    "g.V('user_0').outE()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5540eba-2088-4c12-9116-a77ebb273d50",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%%gremlin\n",
    "g.V('user_0').bothE().where(otherV().hasId('cell_62000'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abe682e6-ef47-4f61-9313-59ddcbeda9d4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "#g.V('user_0').bothE().where(otherV().hasId('cell_62000')).drop()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "159dc9d0-2d9b-4b78-83cc-fb2f081948bf",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%%gremlin\n",
    "g.V()\n",
    ".hasId(\"user_1500\")\n",
    ".outE()\n",
    ".hasLabel(\"user_live_cell\")\n",
    ".inV()\n",
    ".valueMap(true, \"name\")\n",
    ".groupCount()\n",
    ".unfold()\n",
    ".order()\n",
    ".by(values, desc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "653b811a-1d06-4437-811a-8731f24d288f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%%gremlin\n",
    "g.V('user_1500').bothE().where(otherV().hasId('cell_56500'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fb5fa63-3e41-4707-bb2a-c3fc0dbebce6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%%gremlin\n",
    "g.V()\n",
    ".hasId(\"user_1500\")\n",
    ".outE()\n",
    ".hasLabel(\"user_live_cell\")\n",
    ".inV()\n",
    ".valueMap(true, \"name\")\n",
    ".groupCount()\n",
    ".unfold()\n",
    ".order()\n",
    ".by(values, desc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6f2c848-4182-45b7-aa74-90dca1bc26b2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%%gremlin\n",
    "g.V('user_4570').bothE().where(otherV().hasId('cell_10570'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e56f57c-b1b7-47ad-8259-1875fcc6f63f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%%gremlin\n",
    "g.V('user_4570').outE('user_live_cell').hasId('user_live_cell_734598').drop()"
   ]
  },
  {
   "cell_type": "code",
   "id": "83dacf9a-32a0-43e1-97f5-899d9db9e075",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": [],
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "source": [
    "#g.V('user_4570').outE('user_live_cell').hasId('user_live_cell_1678350').drop()"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35700b4f-877b-4a0f-bd34-93e2b6602a7e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%%gremlin\n",
    "g.V('user_4570').bothE().where(otherV().hasId('cell_10570'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6924afe8-9aa7-48a2-84f2-0bb3d2ca309c",
   "metadata": {
    "tags": [],
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Launch the export job"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c614c7f5-92fd-40a6-a602-ef5ec2191a75",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "export_params={ \n",
    "\"command\": \"export-pg\", \n",
    "\"params\": { \"endpoint\": neptune_ml.get_host(),\n",
    "            \"profile\": \"neptune_ml\",\n",
    "            \"useIamAuth\": neptune_ml.get_iam(),\n",
    "            \"cloneCluster\": False,\n",
    "            \"nodeLabels\": [\"user\", \"cell\"],\n",
    "            \"edgeLabels\": [\"user_live_cell\"]\n",
    "            }, \n",
    "\"outputS3Path\": f'{s3_bucket_uri}/neptune-export',\n",
    "\"additionalParams\": {\n",
    "        \"neptune_ml\": {\n",
    "          \"version\": \"v2.0\",\n",
    "          \"targets\": [\n",
    "            {\n",
    "                \"edge\": [\"user\", \"user_live_cell\", \"cell\"],\n",
    "                \"type\" : \"link_prediction\",\n",
    "                \"split_rate\": [0.8, 0.1, 0.1]\n",
    "            }\n",
    "         ]\n",
    "        }\n",
    "      },\n",
    "\"jobSize\": \"xlarge\"}\n",
    "export_params"
   ]
  },
  {
   "cell_type": "code",
   "id": "b39de06e-072a-4ec8-93f1-09f9737d6958",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": [],
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "source": [
    "%%neptune_ml export start --export-url {neptune_ml.get_export_service_host()} --export-iam --wait --store-to export_results\n",
    "${export_params}"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "28005585-0c10-49d7-842b-84fc57694eeb",
   "metadata": {
    "tags": [],
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Data processing/Preparation of graph data for Training"
   ]
  },
  {
   "cell_type": "code",
   "id": "45dc0584-9f3b-4d6d-be60-19603202cfa7",
   "metadata": {
    "tags": [],
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "source": [
    "# The training_job_name can be set to a unique value below, otherwise one will be auto generated\n",
    "training_job_name=neptune_ml.get_training_job_name('link-prediction')\n",
    "\n",
    "processing_params = f\"\"\"\n",
    "--config-file-name training-data-configuration.json\n",
    "--job-id {training_job_name} \n",
    "--instance-type ml.r5.16xlarge\n",
    "--s3-input-uri {export_results['outputS3Uri']}\n",
    "--s3-processed-uri {str(s3_bucket_uri)}/preloading \"\"\""
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "1e4b3f37-4328-4c16-af3d-24a293330c86",
   "metadata": {},
   "source": [
    "%neptune_ml dataprocessing start --wait --store-to processing_results {processing_params}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38fc7d30-a76c-4cc3-a080-8607ef016ebb",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed1e2cd2-220a-4cf7-82d7-9ce655860cfe",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "<div style=\"background-color:#eeeeee; padding:20px; text-align:left; border-radius:10px; margin-top:10px; margin-bottom:10px; \"><b>Information</b>: Link prediction is a more computationally complex model than classification or regression </div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae94cf64-5ddd-4753-bb35-03e035505d61",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "training_params=f\"\"\"\n",
    "--job-id {training_job_name}\n",
    "--data-processing-id {training_job_name}\n",
    "--instance-type ml.g4dn.16xlarge\n",
    "--s3-output-uri {str(s3_bucket_uri)}/training\n",
    "--max-hpo-number 2\n",
    "--max-hpo-parallel 2 \"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c48124c-2e5d-4b52-8c07-b40103e1f3dd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "training_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93d920e4-e32a-42f5-9819-c711f847bea1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%neptune_ml training start --wait --store-to training_results {training_params}\n"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "The model training above used all default parameters to minimize running time and cost but you can modify them to get a stronger model. For example, you can set --max-hpo-number 9 --max-hpo-parallel 3.\n",
    "\n",
    "You can also modify additional model and hyperparameter configurations by following the instructions [here](https://docs.aws.amazon.com/neptune/latest/userguide/machine-learning-customizing-hyperparams.html)\n",
    "\n",
    "For example you can set use-edge-features to True by modifying the `model-hpo-configuration.json` file. However using edge features will lead to errors in inductive inference in Notebook 3b."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "id": "6e49871f-7035-4067-819c-473d99d53dd4",
   "metadata": {
    "tags": [],
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Inference\n",
    "\n",
    "### Endpoint creation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e14a77e-0abd-4122-85bd-5f44fbaab317",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "endpoint_params=f\"\"\"\n",
    "--id {training_job_name}\n",
    "--model-training-job-id {training_job_name}\"\"\"\n",
    "endpoint_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c81ec6c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%neptune_ml endpoint create --wait --store-to endpoint_results {endpoint_params}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0834412-cb3e-45e4-879d-ce6ec788f4db",
   "metadata": {
    "tags": []
   },
   "source": [
    "endpoint_transductive=endpoint_results['endpoint']['name']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7bd4bd6-59d8-4b58-85d2-1652f3744559",
   "metadata": {
    "tags": []
   },
   "source": [
    "### reminder\n",
    "\n",
    "- user_0 and cell_62000 / dropped ALL edges \n",
    "\n",
    "- user_1500 and cell_56500 / dropped ALL edges \n",
    "\n",
    "- user_4570 three edges of cell_10570"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8a16fad-e685-40e8-aa63-fdb1ae9d11ff",
   "metadata": {},
   "source": [
    "<img src=\"attachment:2a0d7696-5e11-42af-a56f-f7acc5d63572.png\" alt=\"image.png\" width=\"1000\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e8cec68-1415-407f-bd5e-74ce11129b25",
   "metadata": {
    "tags": []
   },
   "source": [
    "<div style=\"background-color:#eeeeee; padding:20px; text-align:left; border-radius:10px; margin-top:10px; margin-bottom:10px; \"><b>Experimentation1</b>: GNN is going to predict that user_0 is connected to cell_62000</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0dee7675-7fe3-4541-8ac5-f20d2297ecaa",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%%gremlin\n",
    "g.with(\"Neptune#ml.endpoint\",\"${endpoint_transductive}\")\n",
    ".with(\"Neptune#ml.limit\",10)\n",
    ".V(\"cell_56500\")\n",
    ".in(\"user_live_cell\").with(\"Neptune#ml.prediction\").hasLabel(\"user\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7959809-f20a-4a29-a454-a757b27360d8",
   "metadata": {
    "tags": []
   },
   "source": [
    "<div style=\"background-color:#eeeeee; padding:20px; text-align:left; border-radius:10px; margin-top:10px; margin-bottom:10px; \"><b>Experimentation3</b>: GNN is going to predict that user_1500 is connected to cell_63500</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d823a85-37ed-45ba-949e-f3298dff4c40",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%%gremlin\n",
    "g.with(\"Neptune#ml.endpoint\",\"${endpoint_transductive}\")\n",
    ".with(\"Neptune#ml.limit\",5)\n",
    ".V(\"user_1500\")\n",
    ".out(\"user_live_cell\").with(\"Neptune#ml.prediction\").with(\"Neptune#ml.filterExistingEdges\").hasLabel(\"cell\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a0535ce-94d3-49ce-af3a-8495139e61a3",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "### end to end architecture for multi- scenario"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24859736-ed67-4278-97f3-2c0e354baf0d",
   "metadata": {},
   "source": [
    "<img src=\"attachment:f8b596d3-26cd-4127-aebb-132900d153db.png\" alt=\"image.png\" width=\"1000\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff163b50-96d8-4f1c-96d4-eec3a0a0d94c",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Note on GNN evaluation \n",
    "\n",
    "- HITS@10 provides a measure of how often the model suggests the correct item within the top 10 recommendation\n",
    "\n",
    "- While MR gives an overall indication of how close the correct item is to the top of the list on average.\n",
    "\n",
    "- Sagemaker evaluate the model on the validation and test set\n",
    "\n",
    "- Results on test set \n",
    "\n",
    "    * \"HITS at top 1 (HITS@1)\": 0.4010819758391616,\n",
    "    * \"HITS at top 10 (HITS@10)\": 0.9438622262173598,\n",
    "    * \"HITS at top 3 (HITS@3)\": 0.6301810418539388,\n",
    "    * \"mean rank (MR)\": 3.719502285632852,\n"
   ]
  },
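  {
   "cell_type": "markdown",
   "id": "gnn-eval-metric-definitions",
   "metadata": {},
   "source": [
    "As a rough sketch of what the metrics above measure, assume a set $Q$ of held-out test edges and let $\\mathrm{rank}_i$ be the position of the correct target for edge $i$ among the scored candidates. The symbols $Q$ and $\\mathrm{rank}_i$ are notation introduced here for illustration only; the exact candidate set and tie handling used by Neptune ML may differ.\n",
    "\n",
    "$$\\mathrm{HITS@}k = \\frac{1}{|Q|} \\sum_{i \\in Q} \\mathbf{1}\\left[\\mathrm{rank}_i \\le k\\right], \\qquad \\mathrm{MR} = \\frac{1}{|Q|} \\sum_{i \\in Q} \\mathrm{rank}_i$$\n",
    "\n",
    "Under these definitions, a HITS@10 of about 0.94 means the correct item is among the top 10 candidates for roughly 94% of test edges, and an MR of about 3.7 means the correct item sits between ranks 3 and 4 on average."
   ]
  },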
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fa01006-2d37-4636-9350-0f5a8e116e4c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}